{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Guidance for training a model with your own data"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d3fbf38ad2f4ccaa"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1. Import the necessary packages"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "6e474201f7f64ab4"
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from exp.exp_custom import Exp_Custom"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-13T12:01:15.177812Z",
     "start_time": "2024-06-13T12:01:14.568408500Z"
    }
   },
   "id": "356a8bc9c1e1349c"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 2. Define the hyperparameters"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f4b69019515b59d7"
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Args in experiment:\n",
      "Namespace(activation='gelu', batch_size=16, checkpoints='./checkpoints/', d_core=64, d_ff=128, d_model=128, data='ETTm1', data_path='ETTm1.csv', dropout=0.0, e_layers=2, features='M', freq='h', gpu='0', learning_rate=0.0003, lradj='cosine', model='SOFTS', num_workers=0, patience=3, pred_len=96, root_path='./dataset/ETT-small/', save_model=True, seq_len=96, train_epochs=50, use_gpu=True, use_norm=True)\n"
     ]
    }
   ],
   "source": [
    "# fix seed for reproducibility\n",
    "fix_seed = 2021\n",
    "random.seed(fix_seed)\n",
    "torch.manual_seed(fix_seed)\n",
    "np.random.seed(fix_seed)\n",
    "torch.set_num_threads(6)\n",
    "\n",
    "# basic config\n",
    "config = {\n",
    "    # dataset settings\n",
    "    'root_path': './dataset/ETT-small/',\n",
    "    'data_path': 'ETTm1.csv',\n",
    "    'data': 'ETTm1',\n",
    "    'features': 'M',\n",
    "    'freq': 'h',\n",
    "    'seq_len': 96,\n",
    "    'pred_len': 96,\n",
    "    # model settings\n",
    "    'model': 'SOFTS',\n",
    "    'checkpoints': './checkpoints/',\n",
    "    'd_model': 128,\n",
    "    'd_core': 64,\n",
    "    'd_ff': 128,\n",
    "    'e_layers': 2,\n",
    "    'learning_rate': 0.0003,\n",
    "    'lradj': 'cosine',\n",
    "    'train_epochs': 50,\n",
    "    'patience': 3,\n",
    "    'batch_size': 16,\n",
    "    'dropout': 0.0,\n",
    "    'activation': 'gelu',\n",
    "    'use_norm': True,\n",
    "    # system settings\n",
    "    'num_workers': 0,\n",
    "    'use_gpu': True,\n",
    "    'gpu': '0',\n",
    "    'save_model': True,\n",
    "}\n",
    "\n",
    "parser = argparse.ArgumentParser(description='SOFTS')\n",
    "args = parser.parse_args([])\n",
    "args.__dict__.update(config)\n",
    "args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False\n",
    "\n",
    "print('Args in experiment:')\n",
    "print(args)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-13T12:01:15.242241600Z",
     "start_time": "2024-06-13T12:01:15.178809900Z"
    }
   },
   "id": "e886ee3de8bbb1e7"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 3. Prepare the dataset\n",
    "Organize your data in the following format:\n",
    "- The dataset should be a csv file.\n",
    "- If there is a time feature, the first column contains timestamps in the format 'YYYY-MM-DD HH:MM:SS'. If there's no time feature, the dataset starts directly with the features.\n",
    "- If the parameter `features` is 'M', the following columns are both the features and the targets. If `features` is 'MS', the following columns are the features, and the last column is the target."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "171fc13ff2726f"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                  date   HUFL   HULL   MUFL   MULL   LUFL   LULL         OT\n",
      "0  2016-07-01 00:00:00  5.827  2.009  1.599  0.462  4.203  1.340  30.531000\n",
      "1  2016-07-01 00:15:00  5.760  2.076  1.492  0.426  4.264  1.401  30.459999\n",
      "2  2016-07-01 00:30:00  5.760  1.942  1.492  0.391  4.234  1.310  30.038000\n",
      "3  2016-07-01 00:45:00  5.760  1.942  1.492  0.426  4.234  1.310  27.013000\n",
      "4  2016-07-01 01:00:00  5.693  2.076  1.492  0.426  4.142  1.371  27.787001\n"
     ]
    }
   ],
   "source": [
    "# load data\n",
    "data = pd.read_csv(os.path.join(args.root_path, args.data_path))\n",
    "print(data.head())\n",
    "\n",
    "# split data\n",
    "train_data = data.iloc[: 12 * 30 * 24 * 4]\n",
    "vali_data = data.iloc[12 * 30 * 24 * 4 - args.seq_len: 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4]\n",
    "test_data = data.iloc[12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - args.seq_len: 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]\n",
    "\n",
    "# optional: scale data\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "scaler = StandardScaler()\n",
    "if 'date' in train_data.columns:\n",
    "    scaler.fit(train_data.iloc[:, 1:])\n",
    "    train_data.iloc[:, 1:] = scaler.transform(train_data.iloc[:, 1:])\n",
    "    vali_data.iloc[:, 1:] = scaler.transform(vali_data.iloc[:, 1:])\n",
    "    test_data.iloc[:, 1:] = scaler.transform(test_data.iloc[:, 1:])\n",
    "else:\n",
    "    scaler.fit(train_data.iloc[:, :])\n",
    "    train_data.iloc[:, :] = scaler.transform(train_data.iloc[:, :])\n",
    "    vali_data.iloc[:, :] = scaler.transform(vali_data.iloc[:, :])\n",
    "    test_data.iloc[:, :] = scaler.transform(test_data.iloc[:, :])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-13T12:01:15.594744200Z",
     "start_time": "2024-06-13T12:01:15.246228700Z"
    }
   },
   "id": "8bc7a801398c68de"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 4. Train and Evaluate the model\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c04cef58eb7d57d6"
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Use GPU: cuda:0\n",
      ">>>>>>>start training : ETTm1_SOFTS_96_96>>>>>>>>>>>>>>>>>>>>>>>>>>\n",
      "\titers: 100, epoch: 1 | loss: 0.4506244\n",
      "\tspeed: 0.0274s/iter; left time: 2937.2337s\n",
      "\titers: 200, epoch: 1 | loss: 0.2323690\n",
      "\tspeed: 0.0056s/iter; left time: 600.3217s\n",
      "\titers: 300, epoch: 1 | loss: 0.4256146\n",
      "\tspeed: 0.0054s/iter; left time: 573.5607s\n",
      "\titers: 400, epoch: 1 | loss: 0.2308036\n",
      "\tspeed: 0.0056s/iter; left time: 602.7014s\n",
      "\titers: 500, epoch: 1 | loss: 0.2784870\n",
      "\tspeed: 0.0054s/iter; left time: 574.7486s\n",
      "\titers: 600, epoch: 1 | loss: 0.3876713\n",
      "\tspeed: 0.0053s/iter; left time: 570.3443s\n",
      "\titers: 700, epoch: 1 | loss: 0.2946256\n",
      "\tspeed: 0.0054s/iter; left time: 577.6516s\n",
      "\titers: 800, epoch: 1 | loss: 0.2888232\n",
      "\tspeed: 0.0050s/iter; left time: 536.3711s\n",
      "\titers: 900, epoch: 1 | loss: 0.2390691\n",
      "\tspeed: 0.0052s/iter; left time: 554.7302s\n",
      "\titers: 1000, epoch: 1 | loss: 0.2671814\n",
      "\tspeed: 0.0051s/iter; left time: 545.2673s\n",
      "\titers: 1100, epoch: 1 | loss: 0.2918743\n",
      "\tspeed: 0.0050s/iter; left time: 531.0222s\n",
      "\titers: 1200, epoch: 1 | loss: 0.1836184\n",
      "\tspeed: 0.0051s/iter; left time: 541.6807s\n",
      "\titers: 1300, epoch: 1 | loss: 0.2745768\n",
      "\tspeed: 0.0052s/iter; left time: 553.6080s\n",
      "\titers: 1400, epoch: 1 | loss: 0.2791349\n",
      "\tspeed: 0.0054s/iter; left time: 569.8933s\n",
      "\titers: 1500, epoch: 1 | loss: 0.2686602\n",
      "\tspeed: 0.0069s/iter; left time: 726.2993s\n",
      "\titers: 1600, epoch: 1 | loss: 0.4751884\n",
      "\tspeed: 0.0059s/iter; left time: 624.9110s\n",
      "\titers: 1700, epoch: 1 | loss: 0.2222381\n",
      "\tspeed: 0.0055s/iter; left time: 584.4944s\n",
      "\titers: 1800, epoch: 1 | loss: 0.2600937\n",
      "\tspeed: 0.0053s/iter; left time: 558.6933s\n",
      "\titers: 1900, epoch: 1 | loss: 0.3446293\n",
      "\tspeed: 0.0052s/iter; left time: 547.7753s\n",
      "\titers: 2000, epoch: 1 | loss: 0.2382944\n",
      "\tspeed: 0.0054s/iter; left time: 574.5461s\n",
      "\titers: 2100, epoch: 1 | loss: 0.2264607\n",
      "\tspeed: 0.0059s/iter; left time: 616.7093s\n",
      "Epoch: 1 cost time: 13.926933288574219\n",
      "Validation loss decreased (inf --> 0.420679).  Saving model ...\n",
      "Epoch: 1, Steps: 2149 | Train Loss: 0.2933352 Vali Loss: 0.4206790 Test Loss: 0.3356791\n",
      "Updating learning rate to 0.0002997040092642407\n",
      "\titers: 100, epoch: 2 | loss: 0.2879750\n",
      "\tspeed: 0.0256s/iter; left time: 2698.3506s\n",
      "\titers: 200, epoch: 2 | loss: 0.2676999\n",
      "\tspeed: 0.0050s/iter; left time: 522.5153s\n",
      "\titers: 300, epoch: 2 | loss: 0.3097526\n",
      "\tspeed: 0.0050s/iter; left time: 528.7812s\n",
      "\titers: 400, epoch: 2 | loss: 0.2897035\n",
      "\tspeed: 0.0050s/iter; left time: 525.8588s\n",
      "\titers: 500, epoch: 2 | loss: 0.2596558\n",
      "\tspeed: 0.0051s/iter; left time: 532.9598s\n",
      "\titers: 600, epoch: 2 | loss: 0.3919277\n",
      "\tspeed: 0.0052s/iter; left time: 540.5045s\n",
      "\titers: 700, epoch: 2 | loss: 0.2498662\n",
      "\tspeed: 0.0050s/iter; left time: 526.2182s\n",
      "\titers: 800, epoch: 2 | loss: 0.3093124\n",
      "\tspeed: 0.0065s/iter; left time: 676.0323s\n",
      "\titers: 900, epoch: 2 | loss: 0.2771933\n",
      "\tspeed: 0.0060s/iter; left time: 626.2461s\n",
      "\titers: 1000, epoch: 2 | loss: 0.4784857\n",
      "\tspeed: 0.0054s/iter; left time: 563.8793s\n",
      "\titers: 1100, epoch: 2 | loss: 0.4175684\n",
      "\tspeed: 0.0053s/iter; left time: 550.7327s\n",
      "\titers: 1200, epoch: 2 | loss: 0.3272308\n",
      "\tspeed: 0.0054s/iter; left time: 557.4052s\n",
      "\titers: 1300, epoch: 2 | loss: 0.2790713\n",
      "\tspeed: 0.0052s/iter; left time: 538.4292s\n",
      "\titers: 1400, epoch: 2 | loss: 0.2294715\n",
      "\tspeed: 0.0051s/iter; left time: 527.5278s\n",
      "\titers: 1500, epoch: 2 | loss: 0.2734144\n",
      "\tspeed: 0.0060s/iter; left time: 626.4313s\n",
      "\titers: 1600, epoch: 2 | loss: 0.2556303\n",
      "\tspeed: 0.0054s/iter; left time: 559.7843s\n",
      "\titers: 1700, epoch: 2 | loss: 0.2353841\n",
      "\tspeed: 0.0054s/iter; left time: 556.9034s\n",
      "\titers: 1800, epoch: 2 | loss: 0.2425886\n",
      "\tspeed: 0.0051s/iter; left time: 527.8525s\n",
      "\titers: 1900, epoch: 2 | loss: 0.2627399\n",
      "\tspeed: 0.0055s/iter; left time: 570.4897s\n",
      "\titers: 2000, epoch: 2 | loss: 0.2080787\n",
      "\tspeed: 0.0056s/iter; left time: 574.0719s\n",
      "\titers: 2100, epoch: 2 | loss: 0.2418713\n",
      "\tspeed: 0.0054s/iter; left time: 552.6135s\n",
      "Epoch: 2 cost time: 11.481991529464722\n",
      "Validation loss decreased (0.420679 --> 0.420087).  Saving model ...\n",
      "Epoch: 2, Steps: 2149 | Train Loss: 0.2902201 Vali Loss: 0.4200871 Test Loss: 0.3283640\n",
      "Updating learning rate to 0.0002988172051971717\n",
      "\titers: 100, epoch: 3 | loss: 0.2880904\n",
      "\tspeed: 0.0266s/iter; left time: 2743.1300s\n",
      "\titers: 200, epoch: 3 | loss: 0.2849099\n",
      "\tspeed: 0.0061s/iter; left time: 626.8941s\n",
      "\titers: 300, epoch: 3 | loss: 0.3172294\n",
      "\tspeed: 0.0055s/iter; left time: 565.8185s\n",
      "\titers: 400, epoch: 3 | loss: 0.3554396\n",
      "\tspeed: 0.0050s/iter; left time: 518.2494s\n",
      "\titers: 500, epoch: 3 | loss: 0.3525619\n",
      "\tspeed: 0.0052s/iter; left time: 531.0102s\n",
      "\titers: 600, epoch: 3 | loss: 0.2695810\n",
      "\tspeed: 0.0052s/iter; left time: 534.8646s\n",
      "\titers: 700, epoch: 3 | loss: 0.2115875\n",
      "\tspeed: 0.0052s/iter; left time: 528.7667s\n",
      "\titers: 800, epoch: 3 | loss: 0.3150197\n",
      "\tspeed: 0.0064s/iter; left time: 657.5157s\n",
      "\titers: 900, epoch: 3 | loss: 0.1955295\n",
      "\tspeed: 0.0049s/iter; left time: 502.0578s\n",
      "\titers: 1000, epoch: 3 | loss: 0.3043511\n",
      "\tspeed: 0.0052s/iter; left time: 534.0086s\n",
      "\titers: 1100, epoch: 3 | loss: 0.2362251\n",
      "\tspeed: 0.0049s/iter; left time: 502.6208s\n",
      "\titers: 1200, epoch: 3 | loss: 0.2748100\n",
      "\tspeed: 0.0049s/iter; left time: 497.5573s\n",
      "\titers: 1300, epoch: 3 | loss: 0.2932656\n",
      "\tspeed: 0.0052s/iter; left time: 528.7995s\n",
      "\titers: 1400, epoch: 3 | loss: 0.2300588\n",
      "\tspeed: 0.0049s/iter; left time: 499.1872s\n",
      "\titers: 1500, epoch: 3 | loss: 0.3324521\n",
      "\tspeed: 0.0054s/iter; left time: 547.3210s\n",
      "\titers: 1600, epoch: 3 | loss: 0.2761744\n",
      "\tspeed: 0.0061s/iter; left time: 616.7377s\n",
      "\titers: 1700, epoch: 3 | loss: 0.1916248\n",
      "\tspeed: 0.0051s/iter; left time: 513.1263s\n",
      "\titers: 1800, epoch: 3 | loss: 0.2483542\n",
      "\tspeed: 0.0063s/iter; left time: 633.8177s\n",
      "\titers: 1900, epoch: 3 | loss: 0.2659746\n",
      "\tspeed: 0.0058s/iter; left time: 589.5551s\n",
      "\titers: 2000, epoch: 3 | loss: 0.1941691\n",
      "\tspeed: 0.0056s/iter; left time: 562.9509s\n",
      "\titers: 2100, epoch: 3 | loss: 0.3112955\n",
      "\tspeed: 0.0053s/iter; left time: 532.2117s\n",
      "Epoch: 3 cost time: 11.595311880111694\n",
      "Validation loss decreased (0.420087 --> 0.416802).  Saving model ...\n",
      "Epoch: 3, Steps: 2149 | Train Loss: 0.2737478 Vali Loss: 0.4168020 Test Loss: 0.3290583\n",
      "Updating learning rate to 0.0002973430876093033\n",
      "\titers: 100, epoch: 4 | loss: 0.2467478\n",
      "\tspeed: 0.0261s/iter; left time: 2630.9023s\n",
      "\titers: 200, epoch: 4 | loss: 0.2907024\n",
      "\tspeed: 0.0052s/iter; left time: 521.0184s\n",
      "\titers: 300, epoch: 4 | loss: 0.1952767\n",
      "\tspeed: 0.0052s/iter; left time: 519.3455s\n",
      "\titers: 400, epoch: 4 | loss: 0.2482563\n",
      "\tspeed: 0.0048s/iter; left time: 484.6192s\n",
      "\titers: 500, epoch: 4 | loss: 0.2869221\n",
      "\tspeed: 0.0055s/iter; left time: 551.9477s\n",
      "\titers: 600, epoch: 4 | loss: 0.2432848\n",
      "\tspeed: 0.0051s/iter; left time: 511.8749s\n",
      "\titers: 700, epoch: 4 | loss: 0.2615869\n",
      "\tspeed: 0.0050s/iter; left time: 501.7713s\n",
      "\titers: 800, epoch: 4 | loss: 0.3093148\n",
      "\tspeed: 0.0049s/iter; left time: 492.8370s\n",
      "\titers: 900, epoch: 4 | loss: 0.1970979\n",
      "\tspeed: 0.0049s/iter; left time: 495.3483s\n",
      "\titers: 1000, epoch: 4 | loss: 0.2343193\n",
      "\tspeed: 0.0052s/iter; left time: 517.4276s\n",
      "\titers: 1100, epoch: 4 | loss: 0.2562083\n",
      "\tspeed: 0.0055s/iter; left time: 545.1881s\n",
      "\titers: 1200, epoch: 4 | loss: 0.2822427\n",
      "\tspeed: 0.0054s/iter; left time: 539.5443s\n",
      "\titers: 1300, epoch: 4 | loss: 0.2862186\n",
      "\tspeed: 0.0063s/iter; left time: 626.7433s\n",
      "\titers: 1400, epoch: 4 | loss: 0.2125736\n",
      "\tspeed: 0.0071s/iter; left time: 710.5838s\n",
      "\titers: 1500, epoch: 4 | loss: 0.2899629\n",
      "\tspeed: 0.0051s/iter; left time: 510.1810s\n",
      "\titers: 1600, epoch: 4 | loss: 0.2879072\n",
      "\tspeed: 0.0055s/iter; left time: 545.8996s\n",
      "\titers: 1700, epoch: 4 | loss: 0.2534532\n",
      "\tspeed: 0.0058s/iter; left time: 575.0130s\n",
      "\titers: 1800, epoch: 4 | loss: 0.2427848\n",
      "\tspeed: 0.0051s/iter; left time: 506.5643s\n",
      "\titers: 1900, epoch: 4 | loss: 0.1937304\n",
      "\tspeed: 0.0052s/iter; left time: 511.6990s\n",
      "\titers: 2000, epoch: 4 | loss: 0.2955550\n",
      "\tspeed: 0.0052s/iter; left time: 511.9460s\n",
      "\titers: 2100, epoch: 4 | loss: 0.2177162\n",
      "\tspeed: 0.0053s/iter; left time: 519.9537s\n",
      "Epoch: 4 cost time: 11.526140451431274\n",
      "Validation loss decreased (0.416802 --> 0.407331).  Saving model ...\n",
      "Epoch: 4, Steps: 2149 | Train Loss: 0.2538982 Vali Loss: 0.4073312 Test Loss: 0.3260311\n",
      "Updating learning rate to 0.00029528747416929463\n",
      "\titers: 100, epoch: 5 | loss: 0.2358541\n",
      "\tspeed: 0.0251s/iter; left time: 2482.9074s\n",
      "\titers: 200, epoch: 5 | loss: 0.2369819\n",
      "\tspeed: 0.0059s/iter; left time: 580.4297s\n",
      "\titers: 300, epoch: 5 | loss: 0.2098743\n",
      "\tspeed: 0.0060s/iter; left time: 592.2501s\n",
      "\titers: 400, epoch: 5 | loss: 0.3288191\n",
      "\tspeed: 0.0051s/iter; left time: 504.6819s\n",
      "\titers: 500, epoch: 5 | loss: 0.4056228\n",
      "\tspeed: 0.0051s/iter; left time: 504.0488s\n",
      "\titers: 600, epoch: 5 | loss: 0.2010879\n",
      "\tspeed: 0.0051s/iter; left time: 504.4455s\n",
      "\titers: 700, epoch: 5 | loss: 0.2744911\n",
      "\tspeed: 0.0057s/iter; left time: 561.5379s\n",
      "\titers: 800, epoch: 5 | loss: 0.1945106\n",
      "\tspeed: 0.0055s/iter; left time: 544.0945s\n",
      "\titers: 900, epoch: 5 | loss: 0.1812340\n",
      "\tspeed: 0.0049s/iter; left time: 484.2629s\n",
      "\titers: 1000, epoch: 5 | loss: 0.2142507\n",
      "\tspeed: 0.0063s/iter; left time: 612.0732s\n",
      "\titers: 1100, epoch: 5 | loss: 0.1818398\n",
      "\tspeed: 0.0050s/iter; left time: 487.5360s\n",
      "\titers: 1200, epoch: 5 | loss: 0.2048016\n",
      "\tspeed: 0.0053s/iter; left time: 513.0947s\n",
      "\titers: 1300, epoch: 5 | loss: 0.3843733\n",
      "\tspeed: 0.0051s/iter; left time: 498.8997s\n",
      "\titers: 1400, epoch: 5 | loss: 0.2146500\n",
      "\tspeed: 0.0053s/iter; left time: 520.1366s\n",
      "\titers: 1500, epoch: 5 | loss: 0.1706911\n",
      "\tspeed: 0.0050s/iter; left time: 486.7599s\n",
      "\titers: 1600, epoch: 5 | loss: 0.2873355\n",
      "\tspeed: 0.0052s/iter; left time: 501.6318s\n",
      "\titers: 1700, epoch: 5 | loss: 0.2240126\n",
      "\tspeed: 0.0064s/iter; left time: 620.5410s\n",
      "\titers: 1800, epoch: 5 | loss: 0.2380046\n",
      "\tspeed: 0.0053s/iter; left time: 513.2396s\n",
      "\titers: 1900, epoch: 5 | loss: 0.3126761\n",
      "\tspeed: 0.0050s/iter; left time: 483.1493s\n",
      "\titers: 2000, epoch: 5 | loss: 0.4203324\n",
      "\tspeed: 0.0051s/iter; left time: 494.8558s\n",
      "\titers: 2100, epoch: 5 | loss: 0.2064221\n",
      "\tspeed: 0.0050s/iter; left time: 480.3134s\n",
      "Epoch: 5 cost time: 11.499050617218018\n",
      "EarlyStopping counter: 1 out of 3\n",
      "Epoch: 5, Steps: 2149 | Train Loss: 0.2537079 Vali Loss: 0.4139205 Test Loss: 0.3301461\n",
      "Updating learning rate to 0.00029265847744427303\n",
      "\titers: 100, epoch: 6 | loss: 0.2605379\n",
      "\tspeed: 0.0254s/iter; left time: 2454.5663s\n",
      "\titers: 200, epoch: 6 | loss: 0.1683897\n",
      "\tspeed: 0.0051s/iter; left time: 488.2909s\n",
      "\titers: 300, epoch: 6 | loss: 0.2659087\n",
      "\tspeed: 0.0051s/iter; left time: 493.1072s\n",
      "\titers: 400, epoch: 6 | loss: 0.1988440\n",
      "\tspeed: 0.0057s/iter; left time: 547.4980s\n",
      "\titers: 500, epoch: 6 | loss: 0.2109133\n",
      "\tspeed: 0.0051s/iter; left time: 494.8798s\n",
      "\titers: 600, epoch: 6 | loss: 0.3424521\n",
      "\tspeed: 0.0060s/iter; left time: 581.3801s\n",
      "\titers: 700, epoch: 6 | loss: 0.2633236\n",
      "\tspeed: 0.0052s/iter; left time: 502.7534s\n",
      "\titers: 800, epoch: 6 | loss: 0.3996557\n",
      "\tspeed: 0.0055s/iter; left time: 524.7992s\n",
      "\titers: 900, epoch: 6 | loss: 0.1971629\n",
      "\tspeed: 0.0051s/iter; left time: 484.7656s\n",
      "\titers: 1000, epoch: 6 | loss: 0.1828001\n",
      "\tspeed: 0.0055s/iter; left time: 524.3373s\n",
      "\titers: 1100, epoch: 6 | loss: 0.2320836\n",
      "\tspeed: 0.0050s/iter; left time: 481.6887s\n",
      "\titers: 1200, epoch: 6 | loss: 0.2422465\n",
      "\tspeed: 0.0061s/iter; left time: 581.3289s\n",
      "\titers: 1300, epoch: 6 | loss: 0.2101018\n",
      "\tspeed: 0.0057s/iter; left time: 543.1652s\n",
      "\titers: 1400, epoch: 6 | loss: 0.1730599\n",
      "\tspeed: 0.0048s/iter; left time: 458.0960s\n",
      "\titers: 1500, epoch: 6 | loss: 0.3015074\n",
      "\tspeed: 0.0073s/iter; left time: 692.3004s\n",
      "\titers: 1600, epoch: 6 | loss: 0.2248778\n",
      "\tspeed: 0.0049s/iter; left time: 468.3368s\n",
      "\titers: 1700, epoch: 6 | loss: 0.2006399\n",
      "\tspeed: 0.0054s/iter; left time: 510.3951s\n",
      "\titers: 1800, epoch: 6 | loss: 0.3145465\n",
      "\tspeed: 0.0050s/iter; left time: 476.0263s\n",
      "\titers: 1900, epoch: 6 | loss: 0.2077492\n",
      "\tspeed: 0.0055s/iter; left time: 522.5419s\n",
      "\titers: 2000, epoch: 6 | loss: 0.1928022\n",
      "\tspeed: 0.0074s/iter; left time: 705.1798s\n",
      "\titers: 2100, epoch: 6 | loss: 0.2118652\n",
      "\tspeed: 0.0051s/iter; left time: 482.0632s\n",
      "Epoch: 6 cost time: 11.796943187713623\n",
      "EarlyStopping counter: 2 out of 3\n",
      "Epoch: 6, Steps: 2149 | Train Loss: 0.2381651 Vali Loss: 0.4127691 Test Loss: 0.3371246\n",
      "Updating learning rate to 0.00028946647288323766\n",
      "\titers: 100, epoch: 7 | loss: 0.2260187\n",
      "\tspeed: 0.0275s/iter; left time: 2600.9712s\n",
      "\titers: 200, epoch: 7 | loss: 0.1839273\n",
      "\tspeed: 0.0055s/iter; left time: 520.6293s\n",
      "\titers: 300, epoch: 7 | loss: 0.1823730\n",
      "\tspeed: 0.0052s/iter; left time: 488.1197s\n",
      "\titers: 400, epoch: 7 | loss: 0.1914832\n",
      "\tspeed: 0.0072s/iter; left time: 681.1507s\n",
      "\titers: 500, epoch: 7 | loss: 0.2169700\n",
      "\tspeed: 0.0051s/iter; left time: 478.3268s\n",
      "\titers: 600, epoch: 7 | loss: 0.2720859\n",
      "\tspeed: 0.0056s/iter; left time: 527.7162s\n",
      "\titers: 700, epoch: 7 | loss: 0.1779072\n",
      "\tspeed: 0.0063s/iter; left time: 594.4705s\n",
      "\titers: 800, epoch: 7 | loss: 0.1913169\n",
      "\tspeed: 0.0063s/iter; left time: 587.1492s\n",
      "\titers: 900, epoch: 7 | loss: 0.2312929\n",
      "\tspeed: 0.0052s/iter; left time: 488.9361s\n",
      "\titers: 1000, epoch: 7 | loss: 0.1830410\n",
      "\tspeed: 0.0051s/iter; left time: 479.6148s\n",
      "\titers: 1100, epoch: 7 | loss: 0.3189675\n",
      "\tspeed: 0.0072s/iter; left time: 671.1605s\n",
      "\titers: 1200, epoch: 7 | loss: 0.2285484\n",
      "\tspeed: 0.0051s/iter; left time: 480.2587s\n",
      "\titers: 1300, epoch: 7 | loss: 0.2870460\n",
      "\tspeed: 0.0071s/iter; left time: 661.4106s\n",
      "\titers: 1400, epoch: 7 | loss: 0.1935607\n",
      "\tspeed: 0.0052s/iter; left time: 479.9525s\n",
      "\titers: 1500, epoch: 7 | loss: 0.1845900\n",
      "\tspeed: 0.0050s/iter; left time: 466.7433s\n",
      "\titers: 1600, epoch: 7 | loss: 0.2080109\n",
      "\tspeed: 0.0052s/iter; left time: 485.4923s\n",
      "\titers: 1700, epoch: 7 | loss: 0.1943443\n",
      "\tspeed: 0.0052s/iter; left time: 482.2837s\n",
      "\titers: 1800, epoch: 7 | loss: 0.1713217\n",
      "\tspeed: 0.0066s/iter; left time: 616.3621s\n",
      "\titers: 1900, epoch: 7 | loss: 0.2140364\n",
      "\tspeed: 0.0068s/iter; left time: 625.8200s\n",
      "\titers: 2000, epoch: 7 | loss: 0.3250565\n",
      "\tspeed: 0.0054s/iter; left time: 498.8313s\n",
      "\titers: 2100, epoch: 7 | loss: 0.1656498\n",
      "\tspeed: 0.0050s/iter; left time: 463.7974s\n",
      "Epoch: 7 cost time: 12.472389221191406\n",
      "EarlyStopping counter: 3 out of 3\n",
      "Early stopping\n",
      ">>>>>>>testing : ETTm1_SOFTS_96_96<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n",
      "loading model from ./checkpoints/ETTm1_SOFTS_96_96\\checkpoint.pth\n",
      "mse:0.32603105534740295, mae:0.361756386726899\n"
     ]
    }
   ],
   "source": [
    "Exp = Exp_Custom(args)\n",
    "setting = f'{args.data}_{args.model}_{args.seq_len}_{args.pred_len}'\n",
    "print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))\n",
    "Exp.train(setting=setting, train_data=train_data, vali_data=vali_data, test_data=test_data)\n",
    "print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))\n",
    "Exp.test(setting=setting, test_data=test_data)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-13T12:02:52.982915200Z",
     "start_time": "2024-06-13T12:01:15.589759400Z"
    }
   },
   "id": "77857ed9da69bd61"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 5. Get predictions by the model"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "6c25a79ea1985454"
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading model from ./checkpoints/ETTm1_SOFTS_96_96\\checkpoint.pth\n",
      "(11521, 96, 7)\n"
     ]
    }
   ],
   "source": [
    "# get predictions\n",
    "predictions = Exp.predict(setting=setting, pred_data=test_data)\n",
    "print(predictions.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-06-13T12:02:53.941063400Z",
     "start_time": "2024-06-13T12:02:52.983911800Z"
    }
   },
   "id": "9f0926408d8d19bf"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
